import torch
from transformers import BertModel, BertTokenizer, BertConfig

# # First, the imports above are required.
# tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# # config = BertConfig.from_pretrained('bert-base-chinese')
# # config.update({'output_hidden_states': True})  # update the model config directly here
# # model = BertModel.from_pretrained("bert-base-chinese", config=config)
#
# # The example code above has already been instantiated, so it is not repeated here.
# print(tokenizer.encode("生活的真谛是美和爱"))  # encode a single sentence
# print(tokenizer.encode_plus("生活的真谛是美和爱", "说的太好了"))  # encode a sentence pair
# # The output is as follows:
# '''
# [101, 4495, 3833, 4638, 4696, 6465, 3221, 5401, 1469, 4263, 102]
# {
#  'input_ids': [101, 4495, 3833, 4638, 4696, 6465, 3221, 5401, 1469, 4263, 102, 6432, 4638, 1922, 1962, 749, 102],
#  'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
#  'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# }
# '''
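#
# # The mapping can be inverted with tokenizer.decode -- a quick sanity check,
# # assuming the same bert-base-chinese tokenizer; the decoded string is
# # roughly '[CLS] 生 活 的 真 谛 是 美 和 爱 [SEP]':
# print(tokenizer.decode([101, 4495, 3833, 4638, 4696, 6465, 3221, 5401, 1469, 4263, 102]))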
#
# # The tokenizer can also be called directly, like this:
# sentences = ['网络安全开发分为三个层级',
#              '车辆系统层级网络安全开发',
#              '车辆功能层级网络安全开发',
#              '车辆零部件层级网络安全开发',
#              '测试团队根据车辆网络安全目标制定测试技术要求及测试计划',
#              '测试团队在网络安全团队的支持下，完成确认测试并编制测试报告',
#              '在车辆确认结果的基础上，基于合理的理由，确认在设计和开发阶段识别出的所有风险均已被接受', ]
# test1 = tokenizer(sentences)
#
# print(test1)  # encode a list of sentences
# '''
# {'input_ids': [[101, 5381, 5317, 2128, 1059, 2458, 1355, 1146, 711, 676, 702, 2231, 5277, 102],
#  [101, 6756, 6775, 5143, 5320, 2231, 5277, 5381, 5317, 2128, 1059, 2458, 1355, 102],
#  [101, 6756, 6775, 1216, 5543, 2231, 5277, 5381, 5317, 2128, 1059, 2458, 1355, 102],
#  [101, 6756, 6775, 7439, 6956, 816, 2231, 5277, 5381, 5317, 2128, 1059, 2458, 1355, 102],
#  [101, 3844, 6407, 1730, 7339, 3418, 2945, 6756, 6775, 5381, 5317, 2128, 1059, 4680, 3403, 1169, 2137, 3844, 6407, 2825, 3318, 6206, 3724, 1350, 3844, 6407, 6369, 1153, 102],
#  [101, 3844, 6407, 1730, 7339, 1762, 5381, 5317, 2128, 1059, 1730, 7339, 4638, 3118, 2898, 678, 8024, 2130, 2768, 4802, 6371, 3844, 6407, 2400, 5356, 1169, 3844, 6407, 2845, 1440, 102],
#  [101, 1762, 6756, 6775, 4802, 6371, 5310, 3362, 4638, 1825, 4794, 677, 8024, 1825, 754, 1394, 4415, 4638, 4415, 4507, 8024, 4802, 6371, 1762, 6392, 6369, 1469, 2458, 1355, 7348, 3667, 6399, 1166, 1139, 4638, 2792, 3300, 7599, 7372, 1772, 2347, 6158, 2970, 1358, 102]],
#  'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
#  'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
#
# '''
#
# print(tokenizer("网络安全开发分为三个层级"))  # encode a single sentence
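#
# # For model input you usually want equal-length tensors rather than ragged
# # Python lists. A minimal sketch, assuming the tokenizer, model, and sentence
# # list from above (max_length=64 is an arbitrary illustrative choice):
# batch = tokenizer(sentences, padding=True, truncation=True, max_length=64, return_tensors="pt")
# print(batch["input_ids"].shape)  # torch.Size([7, <longest sequence in batch>])
# # outputs = model(**batch)       # the padded batch can be fed straight to BertModel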
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from transformers import Trainer, Seq2SeqTrainer
import transformers
from transformers import DataCollatorWithPadding
from transformers import TextGenerationPipeline
import numpy as np
import os, re
from tqdm import tqdm
import torch.nn as nn

# Model name
MODEL_NAME = "gpt2"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# GPT-2 has no pad token by default: register '[PAD]', then pin pad_token_id
# to an existing vocab id (0) so padded positions get a fixed, in-vocabulary id
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token_id = 0
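
# If you instead wanted the model to use the newly added '[PAD]' id (which
# lies outside GPT-2's original vocabulary), the model's embedding matrix
# would have to be resized first -- a hedged sketch, not what this script does:
# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# model.resize_token_embeddings(len(tokenizer))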

# Label set
named_labels = ['neg', 'pos']

# Convert each label to a token id (the first sub-token of its encoding)
label_ids = [
    tokenizer.encode_plus(named_labels[i], add_special_tokens=False)["input_ids"][0]
    for i in range(len(named_labels))
]

print(label_ids)
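
# A minimal sketch of how these ids are used downstream (the helper name
# `label_to_target_id` is an assumption, not part of the original script):
# the 0/1 dataset label is mapped to the token id of 'neg'/'pos', i.e. the
# single token the causal LM is trained to emit after the prompt.
def label_to_target_id(label: int) -> int:
    """Map a 0/1 sentiment label to the GPT-2 token id the LM should predict."""
    return label_ids[label]

print(label_to_target_id(0), label_to_target_id(1))  # same ids as label_ids above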
